Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial Implementation of a JAX Salt2 model #191

Merged
merged 2 commits into from
Nov 22, 2024

Conversation

jeremykubica
Copy link
Contributor

Create an initial implementation of the SALT2 model using JAX. This will run slower on CPU (~4x on my laptop), but should be able to take advantage of both auto-differentiation and GPU/TPU. Further optimizations are likely possible, but this gets an initial base implementation checked in.

Copy link

github-actions bot commented Nov 21, 2024

Before [291c03c] After [c652a25] Ratio Benchmark (Parameter)
1.65±0.1ms 1.37±0.03ms ~0.83 benchmarks.TimeSuite.time_apply_passbands
4.63±0.04ms 4.79±0.2ms 1.03 benchmarks.TimeSuite.time_evaluate_salt3_passbands
8.72±0.3ms 8.93±0.8ms 1.02 benchmarks.TimeSuite.time_evaluate_salt3_model
703±20μs 710±800μs 1.01 benchmarks.TimeSuite.time_fnu_to_flam
122±1μs 123±2μs 1.01 benchmarks.TimeSuite.time_sample_x0_from_distmod
15.9±0.2μs 16.0±0.1μs 1.01 benchmarks.TimeSuite.time_sample_x1_from_hostmass
4.42±0.3ms 4.44±1ms 1.00 benchmarks.TimeSuite.time_chained_evaluate
8.99±0.08ms 9.00±0.1ms 1.00 benchmarks.TimeSuite.time_load_passbands
28.3±0.4μs 28.3±0.2μs 1.00 benchmarks.TimeSuite.time_make_new_salt3_model
1.37±0.01s 1.35±0.02s 0.99 benchmarks.TimeSuite.time_make_x1_from_hostmass

Click here to view all benchmarks.

Copy link
Contributor

@hombit hombit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. When I specify wrong model directory for SALT2JaxModel I'm getting InconsistentTableError instead of FileNotFoundError raised by SALT2Source
  2. When trying to run .evaluate / .get_band_fluxes with array parameters for x0, x1, c I've got broadcasting error. Actually it looks like we have similar problems across the code, for example for passbands. So it could go to a new issue.

@jeremykubica
Copy link
Contributor Author

I fixed the first bug (adding a FileNotFoundError). The second one is a known limitation. I have a mental todo list items to better support batch evaluations. I will add a github issue for that.

@jeremykubica jeremykubica merged commit 97f783f into main Nov 22, 2024
5 checks passed
@jeremykubica jeremykubica deleted the jax_salt2_implementation branch December 2, 2024 13:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants